{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem: Write a Custom Activation Function with Autograd\n",
    "\n",
    "### Problem Statement\n",
    "Implement a **custom activation function**, **Learned-SiLU**, using `torch.autograd.Function`. The activation function should be based on the SiLU formula \\( x \\cdot \\text{sigmoid}(x) \\) but include a **learnable slope parameter**. Use this custom activation function in a simple linear regression model.\n",
    "\n",
    "### Requirements\n",
    "1. **Define the Custom Activation Function**:\n",
    "   - Implement a custom activation function, **Learned-SiLU**, where the output is calculated as:\n",
    "     $$\n",
    "     \\text{Learned-SiLU}(x) = \\text{slope} \\cdot x \\cdot \\text{sigmoid}(x)\n",
    "     $$\n",
    "   - The **slope** should be a learnable parameter.\n",
    "\n",
    "2. **Autograd Implementation**:\n",
    "   - Use `torch.autograd.Function` to define the forward and backward passes for the custom activation function.\n",
    "\n",
    "3. **Integrate the Activation Function**:\n",
    "   - Incorporate the custom activation function into a simple linear regression model.\n",
    "   - Train the model to verify the functionality of the activation function.\n",
    "\n",
    "### Constraints\n",
    "- Ensure the **slope parameter** is properly initialized and updated during training.\n",
    "\n",
    "<details>\n",
    "  <summary>💡 Hint</summary>\n",
    "  Some details: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html\n",
    "</details>\n",
    "\n",
    "\n",
    "<details>\n",
    "  <summary>💡 Alternate Implementation?</summary>\n",
    "  Can be done with nn.Module without implementing backward.\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate synthetic data\n",
    "torch.manual_seed(42)\n",
    "X = torch.rand(100, 1) * 10  # 100 data points between 0 and 10\n",
    "y = 2 * X + 3 + torch.randn(100, 1)  # Linear relationship with noise"
   ]
  },
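  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The backward pass in the next cell can be checked by hand. What follows is a short derivation sketch, using the identity $\\text{sigmoid}'(x) = \\text{sigmoid}(x)\\,(1 - \\text{sigmoid}(x))$ and writing $f(x) = \\text{slope} \\cdot x \\cdot \\text{sigmoid}(x)$. The symbols $f$, $L$ (the scalar loss), and the index $i$ are notation assumed here for illustration; they do not come from the problem statement.\n",
    "\n",
    "$$\n",
    "\\frac{\\partial f}{\\partial x} = \\text{slope} \\cdot \\left( \\text{sigmoid}(x) + x \\cdot \\text{sigmoid}(x) \\, (1 - \\text{sigmoid}(x)) \\right)\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\frac{\\partial f}{\\partial \\text{slope}} = x \\cdot \\text{sigmoid}(x)\n",
    "$$\n",
    "\n",
    "Because a single slope is shared by every element of $x$, the chain rule sums the per-element contributions:\n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial \\text{slope}} = \\sum_i \\frac{\\partial L}{\\partial f(x_i)} \\cdot x_i \\cdot \\text{sigmoid}(x_i)\n",
    "$$"
   ]
  },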
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [100/1000], Loss: 1.0552\n",
      "Epoch [200/1000], Loss: 0.8031\n",
      "Epoch [300/1000], Loss: 0.7150\n",
      "Epoch [400/1000], Loss: 0.6826\n",
      "Epoch [500/1000], Loss: 0.6705\n",
      "Epoch [600/1000], Loss: 0.6659\n",
      "Epoch [700/1000], Loss: 0.6642\n",
      "Epoch [800/1000], Loss: 0.6635\n",
      "Epoch [900/1000], Loss: 0.6632\n",
      "Epoch [1000/1000], Loss: 0.6632\n"
     ]
    }
   ],
   "source": [
    "class LearnedSiLUFunction(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x, slope):\n",
    "        # Save the input tensor and slope for backward computation\n",
    "        ctx.save_for_backward(x)\n",
    "        ctx.slope = slope\n",
    "        return slope * x * torch.sigmoid(x)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        # Retrieve the input and slope saved in the forward pass\n",
    "        x, = ctx.saved_tensors\n",
    "        slope = ctx.slope\n",
    "        sigmoid_x = torch.sigmoid(x)\n",
    "\n",
    "        # Compute the gradient with respect to input (x)\n",
    "        grad_input = grad_output * slope * (sigmoid_x + x * sigmoid_x * (1 - sigmoid_x))\n",
    "\n",
    "        # Compute the gradient with respect to slope\n",
    "        grad_slope = grad_output * x * sigmoid_x\n",
    "\n",
    "        return grad_input, grad_slope\n",
    "\n",
    "\n",
    "# Define the Linear Regression Model\n",
    "class LinearRegressionModel(nn.Module):\n",
    "    def __init__(self, slope=1):\n",
    "        super().__init__()\n",
    "        self.slope = nn.Parameter(torch.ones(1) * slope)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Use the custom LearnedSiLUFunction\n",
    "        return LearnedSiLUFunction.apply(x, self.slope)\n",
    "\n",
    "# Initialize the model, loss function, and optimizer\n",
    "model = LinearRegressionModel()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
    "\n",
    "# Training loop\n",
    "epochs = 1000\n",
    "for epoch in range(epochs):\n",
    "    # Forward pass\n",
    "    predictions = model(X)\n",
    "    loss = criterion(predictions, y)\n",
    "\n",
    "    # Backward pass and optimization\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Log progress every 100 epochs\n",
    "    if (epoch + 1) % 100 == 0:\n",
    "        print(f\"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Learned weight: 1.9557, Learned bias: 2.2181\n",
      "Predictions for [[4.0], [7.0]]: [[11.04088020324707], [16.907970428466797]]\n"
     ]
    }
   ],
   "source": [
    "# Display the learned parameters\n",
    "[w, b] = model.linear.parameters()\n",
    "print(f\"Learned weight: {w.item():.4f}, Learned bias: {b.item():.4f}\")\n",
    "\n",
    "# Testing on new data\n",
    "X_test = torch.tensor([[4.0], [7.0]])\n",
    "with torch.no_grad():\n",
    "    predictions = model(X_test)\n",
    "    print(f\"Predictions for {X_test.tolist()}: {predictions.tolist()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
